import numpy as np
import matplotlib.pyplot as plt
from itertools import product
import matplotlib
import seaborn as sns
import brewer2mpl
# Global typography: serif text, Computer Modern ('cm') glyphs for mathtext.
matplotlib.rcParams.update({
    'font.family': 'serif',
    'mathtext.fontset': 'cm',
})
# 7-colour qualitative ColorBrewer "Set1" palette; `colors` is the
# matplotlib-ready list of RGB tuples taken from it.
bmap = brewer2mpl.get_map('Set1', 'qualitative', 7)
colors = bmap.mpl_colors
################################################################################################
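# Himmelblau-type function: (x^2 + y - 1.5*pi)^2 + (x + y^2 - pi)^2 on [-pi, pi]^2 (classic constants 11, 7 replaced by 1.5*pi, pi)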
################################################################################################ 
# Training set: a sample_point x sample_point grid on [-pi, pi]^2.
sample_point = 20
x1 = np.linspace(-np.pi, np.pi, sample_point)
x2 = np.linspace(-np.pi, np.pi, sample_point)

# Every (x1[i], x2[j]) pair with i varying slowest -- the same row order that
# itertools.product(x1, x2) produces -- as a (sample_point**2, 2) array.
train_X = np.stack(np.meshgrid(x1, x2, indexing='ij'), axis=-1).reshape(-1, 2)
train_y = ((train_X[:, 0]**2 + train_X[:, 1] - 1.5*np.pi)**2
           + (train_X[:, 0] + train_X[:, 1]**2 - np.pi)**2)

# Affine map y -> a*y + b sending min(train_y) to -1 and max(train_y) to +1.
min_y = np.min(train_y)
max_y = np.max(train_y)
a = 2 / (max_y - min_y)
b = 1 - 2*max_y / (max_y - min_y)

train_y = train_y * a + b  # now spans [-1, 1]

###################################################################### target function
# Ground-truth surface: the same function as train_y, evaluated on the same grid.
# meshgrid's default 'xy' indexing gives X[i, j] = x1[j] and Y[i, j] = x2[i].
X, Y = np.meshgrid(x1, x2)
f = (X**2 + Y - 1.5*np.pi)**2 + (X + Y**2 - np.pi)**2
# Reuse the affine map (a, b) fitted on train_y, so this surface spans [-1, 1] too.
f = f*a + b

fig = plt.figure()
fig.set_size_inches(10, 10)
ax = fig.add_subplot(111, projection='3d')
ax.tick_params(labelsize=30)
ax.plot_surface(X, Y, f, cmap='viridis', edgecolor='none', linewidth=2)

ax.set_xlabel(r"$x$", fontsize=50, labelpad=20)
ax.set_ylabel(r"$y$", fontsize=50, labelpad=20)
ax.set_zlabel(r"$f(x,y)$", fontsize=50, labelpad=24, rotation=180)
ax.set_zticks([-1, 0, 1])

fig.savefig('extension/approximation/Himmelblau_function.png', dpi=600)


###################################################################### loss
# Training-loss (MSE) curve for the "QNN L=10" model.
sample_nums = 5  # NOTE(review): unused below; presumably the intended number of runs to average
d = 0            # NOTE(review): unused in this file
ls = ['--', '-.', 'dotted', '-']  # candidate line styles; only ls[3] ('-') is used

# Only one run (*_101.npy) is loaded, so average_loss equals that run and
# std_loss is all zeros; the list/average form would extend to several runs.
loss_list = []
loss = np.load('extension/approximation/data/train_loss_2_10_3_101.npy')
loss_list.append(loss.reshape(-1))
loss_list = np.array(loss_list)
average_loss = np.average(loss_list, axis=0)
std_loss = np.std(loss_list, axis=0)

# sns.axes_style() only *returns* a style; calling it bare changes nothing.
# Entering it as a context manager applies the "ticks" style to this figure
# alone. font.family is overridden because seaborn's style dict sets
# sans-serif, which would discard the serif font chosen at the top of the file.
with sns.axes_style("ticks", rc={"font.family": "serif"}):
    fig = plt.figure()
    fig.set_size_inches(12, 10)
    plt.tick_params(labelsize=26)
    plt.plot(average_loss, label="QNN L=10", lw=6, ls=ls[3], c=colors[0], alpha=0.8)
    plt.legend(prop={'size': 40})
    plt.xlabel(r'Iteration', fontsize=40)
    plt.ylabel(r'MSE', fontsize=40)
    plt.xticks(fontsize=36)
    plt.yticks(fontsize=36)
    plt.savefig('extension/approximation/Himmelblau_2_10_3_loss.png', dpi=600)



###################################################################### predict
# Predicted surface f_{theta,L}(x, y). Only one run (*_101.npy) is loaded, so
# average_predict equals that run and std_predict is all zeros.
predict_list = []

predict = np.load('extension/approximation/data/predict_y_2_10_3_101.npy')
predict_list.append(predict)

predict_list = np.array(predict_list)
average_predict = np.average(predict_list, axis=0)
# Assumes the saved predictions follow train_X's row order, i.e.
# itertools.product(x1, x2): flat index i*len(x2) + j <-> (x1[i], x2[j]).
# Derive the shape from the grid instead of hard-coding 20x20, so a different
# sample_point keeps working.
average_predict = average_predict.reshape(len(x1), len(x2))
std_predict = np.std(predict_list, axis=0)

# 'ij' indexing gives grid_x[i, j] = x1[i] and grid_y[i, j] = x2[j], matching
# average_predict[i, j] directly. This coincides with plotting (Y, X) from the
# default 'xy' meshgrid when x1 == x2, and stays correct if the axes differ.
grid_x, grid_y = np.meshgrid(x1, x2, indexing='ij')

fig = plt.figure()
fig.set_size_inches(10, 10)
ax = plt.axes(projection='3d')
plt.tick_params(labelsize=30)
ax.plot_surface(grid_x, grid_y, average_predict, cmap='viridis', edgecolor='none')
ax.set_xlabel(r"$x$", fontsize=50, labelpad=20)
ax.set_ylabel(r"$y$", fontsize=50, labelpad=20)
ax.set_zlabel(r"$f_{\mathbf{\theta},L}(x,y)$", fontsize=50, labelpad=24, rotation=0)
ax.set_zticks(np.arange(-1, 2, 1))
plt.savefig('extension/approximation/Himmelblau_2_10_3_predict_function.png', dpi=600)
